{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [],
      "gpuType": "T4",
      "authorship_tag": "ABX9TyPe2VxGkBat6KPloWI77uxs",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/greengerong/awesome-llm/blob/main/colab/green_starchat_alpha_%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Fine-tuning starchat-alpha\n",
        "\n",
        "Reference: [https://huggingface.co/blog/zh/starchat-alpha](https://huggingface.co/blog/zh/starchat-alpha)"
      ],
      "metadata": {
        "id": "Sdg6uBKsaZ1Q"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W7vTtm4gaNDh"
      },
      "outputs": [],
      "source": [
        "!pip install datasets transformers huggingface_hub"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Preparing the dataset"
      ],
      "metadata": {
        "id": "QrBvbJ95aXPK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "dataset = load_dataset(\"HuggingFaceH4/oasst1_en\")\n",
        "print(dataset)\n",
        "print(\"-------------------\")\n",
        "sample = dataset[\"train_ift\"][0]\n",
        "print(sample)"
      ],
      "metadata": {
        "id": "MYRcQaxOato8"
      },
      "execution_count": null,
      "outputs": []
    },
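    {
      "cell_type": "markdown",
      "source": [
        "A better approach is to use a structured format such as ChatML, which wraps each dialogue turn in special tokens that mark the role of the speaker.\n",
        "\n",
        "In this format, we use the following special tokens:\n",
        "\n",
        "- `<|system|>`: marks the start of the system message, which describes the chatbot's identity and role.\n",
        "- `<|user|>`: marks an utterance from the human user.\n",
        "- `<|assistant|>`: marks an utterance from the AI assistant.\n",
        "- `<|end|>`: marks the end of a turn or of the system message.\n",
        "\n",
        "Let's write a function that wraps our sample in these special tokens:"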
      ],
      "metadata": {
        "id": "0flrCJs_bVv1"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "system_token = \"<|system|>\"\n",
        "user_token = \"<|user|>\"\n",
        "assistant_token = \"<|assistant|>\"\n",
        "end_token = \"<|end|>\"\n",
        "\n",
        "def prepare_dialogue(example):\n",
        "    system_msg = \"Below is a dialogue between a human and an AI assistant called StarChat.\"\n",
        "    prompt = system_token + \"\\n\" + system_msg + end_token + \"\\n\"\n",
        "    for message in example[\"messages\"]:\n",
        "        if message[\"role\"] == \"user\":\n",
        "            prompt += user_token + \"\\n\" + message[\"content\"] + end_token + \"\\n\"\n",
        "        else:\n",
        "            prompt += assistant_token + \"\\n\" + message[\"content\"] + end_token + \"\\n\"\n",
        "    return prompt\n",
        "\n",
        "print(prepare_dialogue(sample))\n"
      ],
      "metadata": {
        "id": "GYR9u5XcbXir"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "That is what the wrapped data looks like! Next, we need to add these special tokens to the tokenizer's vocabulary. Here we download the StarCoder tokenizer and add the special tokens to it:"
      ],
      "metadata": {
        "id": "XzpZcN5fbnjn"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!huggingface-cli login"
      ],
      "metadata": {
        "id": "ojLACPsBb2qR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoTokenizer\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"bigcode/starcoderbase\")\n",
        "tokenizer.add_special_tokens({\"additional_special_tokens\": [\"<|system|>\", \"<|assistant|>\", \"<|user|>\", \"<|end|>\"]})\n",
        "# Check the tokens have been added\n",
        "print(tokenizer.special_tokens_map)\n",
        "print(\"-------------------\")\n",
        "print(tokenizer(\"<|assistant|>\"))"
      ],
      "metadata": {
        "id": "KtehvN5wbocu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Next, we mask the labels of the user turns so that they do not contribute to the loss:"
      ],
      "metadata": {
        "id": "cJ8GpFVFcWtD"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def mask_user_labels(tokenizer, labels):\n",
        "    user_token_id = tokenizer.convert_tokens_to_ids(user_token)\n",
        "    assistant_token_id = tokenizer.convert_tokens_to_ids(assistant_token)\n",
        "    for idx, label_id in enumerate(labels):\n",
        "        if label_id == user_token_id:\n",
        "            current_idx = idx\n",
        "            # Check the bound first so a trailing user turn cannot raise IndexError\n",
        "            while current_idx < len(labels) and labels[current_idx] != assistant_token_id:\n",
        "                labels[current_idx] = -100 # Ignored by the loss\n",
        "                current_idx += 1\n",
        "\n",
        "dialogue = \"<|user|>\\nHello, can you help me?<|end|>\\n<|assistant|>\\nSure, what can I do for you?<|end|>\\n\"\n",
        "input_ids = tokenizer(dialogue).input_ids\n",
        "labels = input_ids.copy()\n",
        "mask_user_labels(tokenizer, labels)\n",
        "print(labels)\n"
      ],
      "metadata": {
        "id": "gob7FycNcXOa"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Fine-tuning StarCoder with DeepSpeed ZeRO-3"
      ],
      "metadata": {
        "id": "hnA6ZPaycguu"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!rm -rf starcoder\n",
        "!git clone https://github.com/greengerong/starcoder.git\n",
        "!pip install -r /content/starcoder/chat/requirements.txt\n",
        "!sudo apt-get install -y git-lfs\n",
        "!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118\n",
        "!pip install -U deepspeed"
      ],
      "metadata": {
        "id": "_i73GtKuch68"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Now we can start training:"
      ],
      "metadata": {
        "id": "ITrFj5UUdRlK"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!torchrun  /content/starcoder/chat/train.py /content/starcoder/chat/config.yaml --deepspeed=/content/starcoder/chat/deepspeed_z3_config_bf16.json"
      ],
      "metadata": {
        "id": "BAkUpGkGdQ0t"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}